{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Yolo-v2 DW2TF validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "# Random input data to test \n",
    "in_data = np.random.randn(64, 608, 608, 3) * 120.56"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Output from Darkflow converted Yolo-v2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: Logging before flag parsing goes to stderr.\n",
      "W0730 11:42:36.966557 140635950143232 deprecation.py:323] From /scratch/anaconda3/envs/tf1.14/lib/python3.6/site-packages/tensorflow/python/training/saver.py:1276: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use standard file APIs to check for files with this prefix.\n"
     ]
    }
   ],
   "source": [
    "tf.reset_default_graph()\n",
    "\n",
    "saver = tf.train.import_meta_graph('/data/darkflow/yolov2.ckpt.meta')\n",
    "ckpt_path = '/data/darkflow/yolov2.ckpt'\n",
    "\n",
    "g = tf.get_default_graph()\n",
    "in_t = g.get_tensor_by_name('input:0')\n",
    "out_t = g.get_tensor_by_name('output:0')\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    saver.restore(sess, ckpt_path)\n",
    "    \n",
    "    out_data_darkflow = sess.run(out_t, feed_dict={in_t: in_data})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Output from DW2TF converted Yolo-v2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf.reset_default_graph()\n",
    "\n",
    "saver = tf.train.import_meta_graph('/data/DW2TF/data/yolov2.ckpt.meta')\n",
    "ckpt_path = '/data/DW2TF/data/yolov2.ckpt'\n",
    "\n",
    "g = tf.get_default_graph()\n",
    "in_t = g.get_tensor_by_name('yolov2/net1:0')\n",
    "out_t = g.get_tensor_by_name('yolov2/convolutional23/BiasAdd:0')\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    saver.restore(sess, ckpt_path)\n",
    "    \n",
    "    out_data_dw2tf = sess.run(out_t, feed_dict={in_t: in_data})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### RMS Error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(64, 19, 19, 425)\n",
      "(64, 19, 19, 425)\n",
      "5.125400967058772e-15\n"
     ]
    }
   ],
   "source": [
    "print(out_data_darkflow.shape)\n",
    "print(out_data_dw2tf.shape)\n",
    "print(np.mean(out_data_darkflow - out_data_dw2tf)**2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ -4.4880652   5.9635143  -4.9752836  -2.7621126 -20.442324   14.926774\n",
      "  18.849686   13.160991    3.2470438 -14.896617 ]\n"
     ]
    }
   ],
   "source": [
    "print(out_data_darkflow[0,0,0,:10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ -4.4880676   5.963505   -4.9753103  -2.7621546 -20.442469   14.926693\n",
      "  18.849705   13.1609535   3.2470534 -14.8967   ]\n"
     ]
    }
   ],
   "source": [
    "print(out_data_dw2tf[0,0,0,:10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
